import pickle
import numpy as np
from inventory_simulator import InventorySimulator
from MDP_mp import MDP
import logging
import argparse
import os


def write_slurm_job_params_sample(path="params.txt"):
    """Write the sample-size sweep as a SLURM parameter file.

    Emits one ``"<n> -1"`` line per entry of ``num_samples``. The first
    column is presumably consumed as the script's ``-n`` option and the
    second as ``-run``; the ``__main__`` section maps ``-1`` to ``None`` for
    both.

    Args:
        path: Output file, overwritten if it exists. Defaults to
            ``"params.txt"`` in the current working directory (the original
            hard-coded location).
    """
    num_samples = [-1, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000,
                   1000000, 2000000, 5000000, 10000000]
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{n} -1\n" for n in num_samples)


def write_slurm_job_params_inventory(path="params.txt"):
    """Write a 50-point float sweep as a SLURM parameter file.

    Each line is ``"<x:.5f> -1 \\n"`` for ``x`` in ``np.linspace(0.02, 1, 50)``
    (0.02, 0.04, ..., 1.00). What the swept value controls is not visible
    here. NOTE(review): the ``__main__`` section parses ``-n`` with
    ``type=int``, so this file is presumably consumed by a different entry
    point — confirm before pointing a job array at it.

    Args:
        path: Output file, overwritten if it exists. Defaults to
            ``"params.txt"`` in the current working directory (the original
            hard-coded location).
    """
    values = np.linspace(0.02, 1, 50)  # inclusive endpoints, step 0.02
    with open(path, "w", encoding="utf-8") as f:
        # The trailing space before "\n" is kept so existing files and
        # readers see byte-identical output.
        f.writelines("{:.5f} {} \n".format(v, -1) for v in values)


if __name__ == '__main__':
    # # ================================================================== #
    # # 🚀 sample
    # # ================================================================== #
    # Uncomment to (re)generate params.txt for the sample-size sweep instead
    # of running an experiment.
    # write_slurm_job_params_sample()
    parser = argparse.ArgumentParser(description="inventory")
    parser.add_argument('-n', type=int, default=None,
                        help="value passed to InventorySimulator.build_mdp; -1 means None")
    parser.add_argument('-run', type=int, default=None,
                        help="run index added to the RNG seed; -1 means None")
    args = parser.parse_args()
    # -1 is the sentinel written into the SLURM params files; map it to None.
    args.n = None if args.n == -1 else args.n
    args.run = None if args.run == -1 else args.run
    print(f"n:{args.n}\nrun:{args.run}")

    save_path = "results/inventory/sample/"
    log_path = "logs/inventory/sample/"
    # Neither logging.FileHandler nor open() creates parent directories;
    # create them up front so a fresh checkout / cluster node doesn't crash.
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(log_path, f"n={args.n}_run={args.run}.log")),  # log to file
            logging.StreamHandler(),  # and to stderr
        ]
    )

    # Seed numpy's global RNG from whichever of n / run are set:
    #   n only -> n;  n and run -> n + run;  run only -> run;
    #   neither -> None (fresh OS entropy, i.e. not reproducible).
    # A plain `args.n + args.run` would raise TypeError for `-n -1 -run k`.
    seed_parts = [p for p in (args.n, args.run) if p is not None]
    np.random.seed(sum(seed_parts) if seed_parts else None)

    # Model parameters, forwarded positionally to InventorySimulator below.
    purchase_cost = 2
    delivery_cost = 0.0
    holding_cost = 0.2
    backlog_cost = 1
    sale_price = 3
    max_inventory = 10
    max_backlog = 5
    max_order = 5
    # Forwarded to MDP(P, R, gamma, kappa, ...).
    gamma = 0.9
    kappa = 1

    # Probability vector passed to InventorySimulator (5 entries, sums to 1).
    demands_prob = np.array([0.1, 0.2, 0.3, 0.3, 0.1])

    simulator = InventorySimulator(demands_prob, purchase_cost, delivery_cost, holding_cost, backlog_cost,
                                   sale_price, max_inventory, max_backlog, max_order)
    P, R = simulator.build_mdp(args.n)
    mdp = MDP(P, R, gamma, kappa, divergence='f', k=2, max_iteration=1000, eps=1e-8)
    v, pi = mdp.solve(rectangular='s')
    xi = mdp.xi

    # Persist the configuration alongside the solution so each pickle is
    # self-describing.
    results = {"params": (
        purchase_cost, delivery_cost, holding_cost, backlog_cost, sale_price, max_inventory, max_backlog, max_order,
        gamma, kappa, demands_prob),
        "n": args.n, "run": args.run, "v": v, "pi": pi, "xi": xi}

    with open(os.path.join(save_path, f"result_n={args.n}_run={args.run}.pkl"), "wb") as f:
        pickle.dump(results, f)
